Add (1-W) weight masking to TROP global method #195

Merged
igerber merged 7 commits into main from feature/trop-global-method-1w-masking on Mar 14, 2026

@igerber
Owner

@igerber igerber commented Mar 8, 2026

Summary

  • Apply (1-W) masking to global method weights so model fits control data only (per paper Eq. 2)
  • Remove tau from joint optimization (unidentifiable under masking); extract post-hoc as residuals
  • Rename method='joint' to method='global' with FutureWarning deprecation alias
  • Add FISTA/Nesterov acceleration to nuclear norm solver for faster L convergence
  • Extract _solve_joint_model and _extract_posthoc_tau helpers to reduce duplication
  • Mirror all Python changes in Rust backend (6 functions)
  • Add 5 new tests: control-only weights, post-hoc residual identity, heterogeneous effects, treated outcome isolation, global alias
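The rename with a deprecation alias described in the list above can be sketched as follows. This is a minimal illustration, not the library's code: the helper name `normalize_method` and the `VALID_METHODS` tuple are hypothetical, and only the method names `twostep`, `global`, and the deprecated `joint` come from this PR.

```python
import warnings

# Hypothetical set of accepted names; this PR only mentions 'twostep',
# 'global', and the deprecated alias 'joint'.
VALID_METHODS = ("twostep", "global")


def normalize_method(method: str) -> str:
    """Return the canonical method name, mapping deprecated 'joint' to 'global'."""
    if method == "joint":
        # FutureWarning is shown to end users by default, unlike DeprecationWarning.
        warnings.warn(
            "method='joint' is deprecated; use method='global' instead.",
            FutureWarning,
            stacklevel=2,
        )
        return "global"
    if method not in VALID_METHODS:
        raise ValueError(f"method must be one of {VALID_METHODS}, got {method!r}")
    return method
```

Calling `normalize_method("joint")` returns `"global"` and emits a `FutureWarning`, so existing callers keep working while being nudged toward the new name.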

Methodology references

  • Method name(s): TROP (Triply Robust Panel Estimator), global estimation method
  • Paper / source link(s): Athey, Imbens, Qu & Viviano (2025), Equation 2 — (1-W)*delta weighting
  • Any intentional deviations from the source: tau is no longer a CVXPY variable (unidentified under (1-W) masking); extracted post-hoc as mean(Y - mu - alpha - beta - L) over treated cells. This matches the paper's intent and produces less-biased ATT estimates.
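The post-hoc extraction described above, mean(Y - mu - alpha - beta - L) over treated cells, might look roughly like this. The function name `posthoc_tau` and the array layout (units in rows, periods in columns, `alpha` per unit, `beta` per period) are assumptions made for the sketch; the library's `_extract_posthoc_tau` may be organised differently.

```python
import numpy as np


def posthoc_tau(Y, D, mu, alpha, beta, L):
    """Residual treatment effects on treated cells and their mean (ATT).

    Assumed layout: Y, D, L have shape (n_units, n_periods);
    alpha has shape (n_units,), beta has shape (n_periods,).
    """
    # Counterfactual fitted on control data only.
    Y0_hat = mu + alpha[:, None] + beta[None, :] + L
    tau = Y - Y0_hat
    # Only treated cells with an observed (finite) outcome contribute.
    treated = (D == 1) & np.isfinite(Y)
    att = float(tau[treated].mean()) if treated.any() else float("nan")
    # Per-cell effects; NaN marks control cells and missing treated outcomes.
    return np.where(treated, tau, np.nan), att
```

Because tau is computed cell by cell, heterogeneous effects are preserved and the ATT is just their average.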

Validation

  • Tests added/updated: tests/test_trop.py — 5 new tests (test_global_*), 4 updated tests (test_method_in_get_params, test_method_in_set_params, test_method_set_params_joint_deprecated, test_joint_rejects_staggered_adoption)
  • Monte Carlo comparison against CVXPY reference (20 reps × 5 lambda configs): no-lowrank configs match exactly; low-rank mean |Δτ| = 0.0004 (λ_nn=0.1). All comparisons within 0.10.
  • Rust backend equivalence: 10/10 joint Rust-vs-Python parity tests pass
  • Full test suite: 90 TROP + 78 Rust backend + 1438 other tests pass

Security / privacy

  • Confirm no secrets/PII in this PR: Yes

Generated with Claude Code

igerber and others added 4 commits March 8, 2026 17:46
…ISTA acceleration

The soft-thresholding threshold for the L matrix was λ_nn/max(δ) when the
correct value is λ_nn/(2·max(δ)), derived from the Lipschitz constant
L_f = 2·max(δ) of the quadratic gradient. This over-shrinking caused
singular values to be reduced too aggressively.

Fix applied to all four code paths: Python joint, Python twostep, Rust
joint, Rust twostep. Also adds FISTA/Nesterov acceleration to the twostep
inner solver for faster L convergence (O(1/k²) vs O(1/k)).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
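The corrected threshold in the commit message above follows from a short derivation. The sketch below assumes the smooth part of the objective is the weighted quadratic $f(L)=\sum_{i,t}\delta_{it}(R_{it}-L_{it})^2$, where $R$ denotes the residual after removing the fixed effects; that notation is introduced here for illustration and is not taken from the code.

```latex
\begin{aligned}
\nabla f(L) &= 2\,\delta \circ (L - R), \\
\|\nabla f(L) - \nabla f(L')\|_F &\le 2\max(\delta)\,\|L - L'\|_F
  \quad\Rightarrow\quad L_f = 2\max(\delta), \\
\eta &= \frac{1}{L_f} = \frac{1}{2\max(\delta)}, \\
L^{+} &= \operatorname{prox}_{\eta\lambda_{nn}\|\cdot\|_*}\!\bigl(L - \eta\,\nabla f(L)\bigr)
       = \mathrm{SVT}_{\eta\lambda_{nn}}\!\bigl(L - \eta\,\nabla f(L)\bigr), \\
\eta\,\lambda_{nn} &= \frac{\lambda_{nn}}{2\max(\delta)}.
\end{aligned}
```

Using $\lambda_{nn}/\max(\delta)$ instead doubles the amount subtracted from each singular value, which is the over-shrinking the commit describes.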
…ghts

The proximal threshold was hardcoded as λ/(2) which is only correct when
W_max=1. Changed to λ/(2·W_max) to match the joint solver and Rust backend.
Added test with non-uniform weights and updated REGISTRY.md algorithm docs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…rings

- Add conditional threshold when W_max==0 to prevent ZeroDivisionError,
  matching Rust backend behavior (trop.rs:665)
- Update Python and Rust docstrings to reflect correct FISTA/Nesterov
  acceleration formulas (L_f = 2·max(W), η = 1/(2·max(W)))
- Add regression test for all-zero weights edge case

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
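A compact sketch of a FISTA loop using the constants quoted above (L_f = 2·max(W), η = 1/(2·max(W))) is shown below. The function names and the objective form (weighted squared error plus a nuclear-norm penalty) are assumptions for illustration. The zero-weight fallback of returning a zero matrix is also an assumption; the commit only states that a guard exists, not what it returns.

```python
import numpy as np


def svt(M, threshold):
    """Soft-threshold the singular values of M."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return (U * np.maximum(s - threshold, 0.0)) @ Vt


def fista_lowrank(R, W, lambda_nn, max_iter=500, tol=1e-10):
    """Approximately minimise sum(W * (R - L)**2) + lambda_nn * ||L||_*."""
    R = np.asarray(R, dtype=float)
    w_max = float(np.max(W))
    if w_max == 0.0:
        # Assumed fallback: with all weights zero only the penalty remains,
        # and the zero matrix minimises it. The library may behave differently.
        return np.zeros_like(R)
    eta = 1.0 / (2.0 * w_max)  # step size 1 / L_f
    L = np.zeros_like(R)
    Z = L.copy()  # extrapolated (momentum) point
    t = 1.0
    for _ in range(max_iter):
        grad = 2.0 * W * (Z - R)
        L_new = svt(Z - eta * grad, eta * lambda_nn)
        # Standard Nesterov/FISTA momentum update.
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        Z = L_new + ((t - 1.0) / t_new) * (L_new - L)
        done = np.max(np.abs(L_new - L)) < tol
        L, t = L_new, t_new
        if done:
            break
    return L
```

With uniform unit weights the problem has the closed-form solution `svt(R, lambda_nn / 2)`, which gives a convenient sanity check for the threshold.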
Align the TROP global method with the paper's Eq. 2 by adding (1-W)
masking so the model is fit on control data only, then extracting
treatment effects post-hoc as residuals (tau_it = Y - mu - alpha - beta - L).

Key changes:
- Apply (1-W) masking in _compute_joint_weights, zeroing treated cells
- Remove tau from the joint solvers (no longer identifiable under masking)
- Extract per-observation treatment effects post-hoc; ATT = mean(tau_it)
- Add FISTA/Nesterov acceleration to the nuclear norm solver (O(1/k²))
- Rename method='joint' to method='global' with FutureWarning deprecation
- Extract _solve_joint_model and _extract_posthoc_tau helpers to reduce duplication
- Mirror all changes in Rust backend

Monte Carlo validation (20 reps × 5 configs) shows:
- No-lowrank configs: exact match with CVXPY reference (|Δτ| = 0)
- Low-rank configs: mean |Δτ| = 0.0004 (λ_nn=0.1), 0.026 (λ_nn=0.01)
- 100% of comparisons within 0.10 of reference

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@igerber
Owner Author

igerber commented Mar 8, 2026

/ai-review

@github-actions

github-actions bot commented Mar 8, 2026

🔁 AI review rerun (requested by @igerber)

Head SHA: bb6bf5c54ce7c2898fb6497167fc4c98bf92a525


Overall Assessment

⚠️ Needs changes

Executive Summary

  • method="global" (and deprecated joint) is still being documented as if it were the paper’s Eq. 2 / Eq. 5 estimator, but the paper’s multiple-treated TROP procedure remains the per-treated-cell estimator in Eq. 12 / Algorithm 2. The new path may be a reasonable approximation, but it is not paper-exact as currently described.
  • The Python low-rank global solver returns mu/alpha/beta from the iteration before the final L update, while the Rust path re-solves on the converged L. The new post-hoc ATT extraction uses those stale Python coefficients directly.
  • The Python bootstrap fallback now appends np.nan ATT draws when a resample has no finite treated cells, so one invalid draw can poison the entire SE/inference path instead of being dropped.
  • The Methodology Registry is internally inconsistent after the FISTA change: one updated section documents FISTA, another still says the Python joint solver uses a single proximal step.

Methodology

  • Severity: P1. Affected method: TROP method="global" / deprecated joint. Impact: the new public description in diff_diff/trop.py:74, diff_diff/trop.py:1228, docs/methodology/REGISTRY.md:1182, and docs/methodology/REGISTRY.md:1230 overstates paper alignment. Eq. 2 is the masked single-treated-cell fit, while the paper’s multiple-treated estimator averages per-treated-cell fits via Eq. 12 / Algorithm 2; alternative identifying assumptions are discussed separately, not as the main estimator. Concrete fix: label global explicitly as a library-specific approximation/adaptation, scope the Eq. 2 reference to the (1-W) masking / residual extraction idea only, and point paper-faithful users to method="twostep".
  • Severity: P1. Affected method: global bootstrap variance / SE path. Impact: diff_diff/trop.py:704 returns np.nan ATT when a resample has no finite treated cells, but the Python fallback blindly appends that draw at diff_diff/trop.py:1404 and then computes np.std at diff_diff/trop.py:1422. One bad resample can therefore make the whole SE NaN. The Rust path instead drops such iterations at rust/src/trop.rs:1755. Concrete fix: skip non-finite bootstrap ATT draws in Python, warn on dropped iterations, and match Rust’s <2 valid draws => NaN SE behavior instead of returning 0.0 for the empty case.
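The bootstrap fix suggested in the second finding above could be sketched as follows. The helper name `bootstrap_se` is hypothetical, and `ddof=1` is an assumption; the review only says that `np.std` is used.

```python
import warnings

import numpy as np


def bootstrap_se(att_draws):
    """Standard error from bootstrap ATT draws, ignoring non-finite draws."""
    draws = np.asarray(att_draws, dtype=float)
    valid = draws[np.isfinite(draws)]
    n_dropped = draws.size - valid.size
    if n_dropped:
        warnings.warn(
            f"Dropped {n_dropped} non-finite bootstrap ATT draw(s).",
            UserWarning,
            stacklevel=2,
        )
    if valid.size < 2:
        # A spread cannot be estimated from fewer than two draws.
        return float("nan")
    return float(np.std(valid, ddof=1))  # ddof=1 is an assumption of this sketch
```

Filtering before `np.std` means one failed resample costs a single draw instead of turning the whole SE into NaN, and the `< 2` guard avoids reporting a fabricated 0.0.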

Code Quality

  • Severity: P1. Affected method: global low-rank solver with finite lambda_nn. Impact: diff_diff/trop.py:943 solves mu/alpha/beta against the current L, then diff_diff/trop.py:976 updates L and returns immediately on convergence. diff_diff/trop.py:1226 and diff_diff/trop.py:1467 then use those stale coefficients for post-hoc ATT extraction. The Rust mirror already performs a final re-solve on the converged L at rust/src/trop.rs:1406. Concrete fix: after the outer loop, run one final _solve_joint_no_lowrank(Y_safe - L, delta_masked) before returning, and add a finite-lambda_nn Python/Rust parity regression.
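The stale-coefficient issue and its fix can be illustrated with a small alternating solver. Everything here is a simplified stand-in: `solve_fe` and `fit_lowrank` are hypothetical names, the fixed effects are returned as a fitted matrix rather than as mu/alpha/beta, and a single proximal step replaces the inner FISTA loop for brevity. The sketch assumes Y has no missing values and that W has at least one positive entry.

```python
import numpy as np


def solve_fe(Y, W):
    """Weighted LS fit of Y_it ~ mu + alpha_i + beta_t; returns the fitted matrix."""
    n, T = Y.shape
    i, t = np.divmod(np.arange(n * T), T)  # unit and period index of each cell
    rows = np.arange(n * T)
    X = np.zeros((n * T, n + T - 1))
    X[:, 0] = 1.0  # intercept
    X[rows[i > 0], i[i > 0]] = 1.0  # unit dummies (first unit is the baseline)
    X[rows[t > 0], n - 1 + t[t > 0]] = 1.0  # period dummies (first period is the baseline)
    sw = np.sqrt(W.ravel())
    coef, *_ = np.linalg.lstsq(X * sw[:, None], Y.ravel() * sw, rcond=None)
    return (X @ coef).reshape(n, T)


def svt(M, threshold):
    """Soft-threshold the singular values of M."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return (U * np.maximum(s - threshold, 0.0)) @ Vt


def fit_lowrank(Y, W, lambda_nn, max_iter=100, tol=1e-8):
    """Alternate between fixed effects and L, then re-solve the fixed effects."""
    L = np.zeros_like(Y)
    eta = 1.0 / (2.0 * W.max())
    for _ in range(max_iter):
        fe = solve_fe(Y - L, W)  # fixed effects given the current L
        grad = 2.0 * W * (L - (Y - fe))
        L_new = svt(L - eta * grad, eta * lambda_nn)  # L moves after fe was computed
        converged = np.max(np.abs(L_new - L)) < tol
        L = L_new
        if converged:
            break
    # Final re-solve: without this line, fe would correspond to the previous L.
    fe = solve_fe(Y - L, W)
    return fe, L
```

The last `solve_fe` call is the point of the example: the returned fixed effects are guaranteed to be the weighted-LS solution for the returned L, which is what the post-hoc ATT extraction needs.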

Performance

  • No material performance finding from the diff alone.

Maintainability

  • No separate maintainability finding beyond the methodology/documentation drift below.

Tech Debt

  • No separate tech-debt finding beyond the need to keep the paper-exact and approximation paths clearly distinguished.

Security

  • No security issues or secret exposure found in the diff.

Documentation/Tests

  • Severity: P2. Impact: the registry’s generic optimization section still says “the Python joint solver uses a single proximal gradient step per outer alternating iteration” at docs/methodology/REGISTRY.md:1084, but the changed code now uses an inner FISTA loop in both diff_diff/trop.py:953 and rust/src/trop.rs:1365. The added global-mode tests starting at tests/test_trop.py:3256 also do not cover the Python/Rust parity gap above or the bootstrap “no finite treated cells” edge case. Concrete fix: update the registry text to match the implemented FISTA loop and add regressions for finite-lambda_nn global parity plus bootstrap resamples with all-NaN treated outcomes.

Static review only. I could not run numpy-based reproductions in this shell because numpy is not installed here.

…isoning

- Add final re-solve after outer loop convergence in _solve_joint_with_lowrank
  to ensure mu/alpha/beta are consistent with converged L (matches Rust)
- Filter NaN ATT draws in bootstrap fallback with np.isfinite check
- Clarify global method docs as adaptation of Eq. 2 masking principle,
  not paper-exact Algorithm 2
- Update FISTA documentation to reflect both solvers now use acceleration

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@igerber
Owner Author

igerber commented Mar 8, 2026

/ai-review

@github-actions

github-actions bot commented Mar 8, 2026

🔁 AI review rerun (requested by @igerber)

Head SHA: ab913f1ba58543bb341d783b01895e92c0826d1c


Overall Assessment

⚠️ Needs changes

Executive Summary

  • Affected methods: TROP method='global' / deprecated method='joint', plus the weighted nuclear-norm solver path used by global and twostep.
  • The Python global estimator is aligned with the updated registry on (1-W) masking and post-hoc residual ATT extraction.
  • The Rust global backend is not equivalent to the registry/Python weighted-LS solve after this change; its new stopping rule can exit before fixed effects converge, which affects Rust LOOCV and bootstrap.
  • The new post-hoc ATT path has a missing-treated edge case: it drops non-finite treated outcomes from ATT but still reports results and runs inference using all D==1 cells, and it does not explicitly handle the zero-valid-treated case.
  • Bootstrap undefined-SE behavior is now backend-dependent (0.0 in Python fallback vs NaN in Rust).
  • Docs/tests were only partially updated: the changed API docs still describe joint as a scalar-τ method, and no Rust parity coverage was added for the new global semantics.

Methodology

  • P1 Impact: The Rust replacement for the global weighted-LS solve is no longer the same estimator as the registry/Python path. The registry says the no-lowrank global fit is weighted least squares, and Python implements it with exact lstsq; Rust now stops solve_joint_no_lowrank() as soon as mu stops moving, without checking alpha/beta. That is not a valid convergence criterion for two-way FE, so Rust-enabled LOOCV/bootstrap can select different lambdas or SEs than the exact Python fit. See docs/methodology/REGISTRY.md:L1217-L1227, rust/src/trop.rs:L1250-L1303, rust/src/trop.rs:L1345-L1412, rust/src/trop.rs:L1455-L1463, rust/src/trop.rs:L1739-L1753. Concrete fix: implement the same exact WLS solve in Rust as Python, or at minimum require convergence on max change across mu, alpha, and beta before returning.
  • P1 Impact: The new post-hoc ATT path silently changes the sample behind the estimand when treated outcomes are missing. _extract_posthoc_tau() averages only valid_treated = (D == 1) & isfinite(Y), but _fit_joint() still sets n_treated_obs and df_trop from all D == 1 cells, and it has no explicit empty-result guard when valid_treated.sum() == 0. That is a new empty-result path and an inference mismatch against the registry’s “ATT = mean over treated observations.” See docs/methodology/REGISTRY.md:L1201-L1205, diff_diff/trop.py:L699-L705, diff_diff/trop.py:L1070-L1073, diff_diff/trop.py:L1253-L1269. Concrete fix: compute n_valid_treated, raise or warn when it is zero, and use that count consistently for ATT metadata and df_trop (or store both total treated and valid treated explicitly).
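The convergence criterion proposed in the first finding above (maximum change across mu, alpha, and beta) can be sketched in Python for readability, although the affected code is in Rust. The function `backfit_fe` is a hypothetical illustration of an iterative two-way fixed-effects loop, not a port of `solve_joint_no_lowrank`, and it assumes every unit and every period has positive total weight.

```python
import numpy as np


def backfit_fe(Y, W, max_iter=1000, tol=1e-10):
    """Iterative weighted fit of Y_it ~ mu + alpha_i + beta_t.

    Stops only when mu, alpha, and beta have all stopped moving.
    """
    n, T = Y.shape
    mu, alpha, beta = 0.0, np.zeros(n), np.zeros(T)
    w_tot, w_unit, w_time = W.sum(), W.sum(axis=1), W.sum(axis=0)
    for _ in range(max_iter):
        mu_new = (W * (Y - alpha[:, None] - beta[None, :])).sum() / w_tot
        alpha_new = (W * (Y - mu_new - beta[None, :])).sum(axis=1) / w_unit
        beta_new = (W * (Y - mu_new - alpha_new[:, None])).sum(axis=0) / w_time
        # Checking mu alone can stop while alpha and beta are still changing.
        change = max(
            abs(mu_new - mu),
            float(np.max(np.abs(alpha_new - alpha))),
            float(np.max(np.abs(beta_new - beta))),
        )
        mu, alpha, beta = mu_new, alpha_new, beta_new
        if change < tol:
            break
    return mu, alpha, beta
```

On data that are exactly additive, the fitted surface `mu + alpha[:, None] + beta[None, :]` should reproduce Y even on zero-weight cells, which makes a simple parity check against an exact least-squares solve possible.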

Code Quality

  • P2 Impact: Undefined bootstrap SE handling is now backend-dependent. The Python fallback drops non-finite ATT draws but still returns se=0.0 when zero bootstrap fits succeed, while Rust returns NaN when fewer than two draws exist. That leaves results.se inconsistent across backends for the same estimator. See diff_diff/trop.py:L1408-L1427, rust/src/trop.rs:L1755-L1789. Concrete fix: make the Python fallback match Rust (NaN for <2 successful resamples) and add a regression test forcing 0/1 successful bootstrap draws.

Performance

  • No material diff-specific findings.

Maintainability

  • No material diff-specific findings.

Tech Debt

  • No material diff-specific findings.

Security

  • No material findings in the diff.

Documentation/Tests

  • P2 Impact: The changed API docs were only partially updated. The same file that now documents method='global' and heterogeneous residual-based effects still has a later example describing method='joint' as a “single scalar treatment effect via weighted LS,” which is now false and also recommends the deprecated API. The new regression tests are also Python-only even though the PR mirrors the estimator changes into Rust. See docs/api/trop.rst:L122-L149, docs/api/trop.rst:L214-L234, tests/test_trop.py:L3256-L3425, tests/test_rust_backend.py:L1272-L1683. Concrete fix: rewrite the example to method='global' with residual-based ATT wording, and add Rust/Python parity tests for control-only masking, post-hoc residual ATT, and treated-outcome isolation.

igerber added a commit that referenced this pull request Mar 14, 2026
Address five failure modes observed after gpt-5.4 upgrade (PRs #192, #194,
#195): documented deviations flagged as blockers, deferred work not accepted,
moving goalposts on re-review, undefined approval criteria, and valid
implementation choices treated as methodology errors.

Changes to pr_review.md:
- Exempt REGISTRY.md-documented deviations from P0/P1 (classify as P3)
- Add implementation choice exception for valid numerical approaches
- Add Deferred Work Acceptance section honoring TODO.md tracking
- Add Assessment Criteria with objective verdicts and mitigation rules
- Add Re-review Scope rules to prevent oscillation between rounds

Changes to ai_pr_review.yml:
- Add step to fetch previous AI review comment for re-review context
- Inject prior review findings into compiled prompt on /ai-review reruns

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@igerber
Owner Author

igerber commented Mar 14, 2026

/ai-review

@github-actions

🔁 AI review rerun (requested by @igerber)

Head SHA: ab913f1ba58543bb341d783b01895e92c0826d1c


Overall Assessment

⚠️ Needs changes

Executive Summary

  • Affected methods: TROP method='global' / deprecated method='joint', plus the weighted nuclear-norm low-rank solver used by global and twostep.
  • The core Python rewrite to (1-W) masking plus post-hoc residual ATT extraction is now documented in the registry and looks aligned; I do not count that behavior change itself as a defect.
  • Two prior P1 findings remain open: the Rust global weighted-LS solve is still not the documented/Python estimator, and missing treated outcomes are still dropped from ATT while df/metadata are computed from all D==1 cells.
  • The bootstrap path is safer than before because inference now goes through safe_inference(), but the global path still reports se=0.0 when zero bootstrap fits succeed.
  • Docs/tests are still only partially updated: one API example still describes deprecated joint scalar-τ semantics, and the new global-method regression coverage is still Python-only.

Methodology
Cross-checking docs/methodology/REGISTRY.md:L1070-L1093 and docs/methodology/REGISTRY.md:L1179-L1276 against diff_diff/trop.py:L652-L719 and diff_diff/trop.py:L2046-L2145, the Python (1-W) masking rewrite and the proximal-threshold/FISTA update look documented and defensible.

  • Severity P1. Impact: the Rust global no-lowrank solve is still not the documented estimator. The registry now specifies exact weighted least squares for the no-lowrank global method in docs/methodology/REGISTRY.md:L1217-L1229, and Python implements that with np.linalg.lstsq in diff_diff/trop.py:L799-L881. Rust still uses an iterative FE loop in rust/src/trop.rs:L1207-L1412 that stops when only mu stabilizes, and that path still drives Rust LOOCV/bootstrap in diff_diff/trop.py:L1127-L1145 and diff_diff/trop.py:L1329-L1363. That is an undocumented methodology deviation, not just an implementation choice, because Rust-enabled runs can pick different λs/SEs than the documented Python estimator. Concrete fix: replace solve_joint_no_lowrank() with the same exact weighted-LS solve used in Python, or at minimum require convergence on the max change across mu, alpha, and beta, then add Rust/Python parity tests for both no-lowrank and lowrank global fits.
  • Severity P1. Impact: the new post-hoc ATT path still changes the treated sample when treated outcomes are missing, but df and metadata still use all D==1 cells. _extract_posthoc_tau() averages only finite treated outcomes in diff_diff/trop.py:L699-L719, while _fit_joint() still computes n_treated_obs from all treated cells in diff_diff/trop.py:L1068-L1073 and uses that count for df_trop and results metadata in diff_diff/trop.py:L1253-L1269. That is inconsistent with the registry ATT definition in docs/methodology/REGISTRY.md:L1201-L1205, and it leaves the zero-valid-treated path as a silent NaN result instead of an explicit empty-result error/warning. Concrete fix: compute n_valid_treated = np.sum((D == 1) & np.isfinite(Y)), raise or warn when it is zero, and use n_valid_treated consistently for ATT metadata, df_trop, and any bootstrap/result summaries (or expose both total and valid treated counts explicitly).
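The counting fix in the second finding might be sketched like this. The helper name `count_valid_treated` is hypothetical, and warning (rather than raising) on the empty case is one of the two options the review allows.

```python
import warnings

import numpy as np


def count_valid_treated(Y, D):
    """Return (total treated cells, treated cells with a finite outcome)."""
    n_treated = int(np.sum(D == 1))
    n_valid_treated = int(np.sum((D == 1) & np.isfinite(Y)))
    if n_valid_treated == 0:
        # Make the empty-result path explicit instead of a silent NaN ATT.
        warnings.warn(
            "No treated cells with finite outcomes; ATT is undefined.",
            UserWarning,
            stacklevel=2,
        )
    return n_treated, n_valid_treated
```

Using the second count for df_trop and the reported metadata keeps the degrees of freedom consistent with the cells that actually enter the ATT average.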

Code Quality

  • Severity P2. Impact: the global bootstrap path still returns se=0.0 when zero bootstrap fits succeed in both the Rust wrapper and Python fallback, via diff_diff/trop.py:L1355-L1363 and diff_diff/trop.py:L1419-L1428. safe_inference() prevents bogus t/p/CI, but results.se itself is still a fabricated value instead of an undefined one. Concrete fix: return np.nan for <2 successful bootstrap draws in both branches and add a regression test that forces 0 and 1 successful global bootstrap fits.

Performance
No material diff-specific findings.

Maintainability
No material diff-specific findings beyond the documentation/test drift noted below.

Tech Debt
No material diff-specific findings. There is no relevant mitigating entry under TODO.md:L38-L62, so the open TROP issues remain unmitigated.

Security
No material findings in the diff.

Documentation/Tests

  • Severity P2. Impact: the API/docs/tests update is still partial. docs/api/trop.rst:L214-L234 still presents method='joint' as a single-scalar weighted-LS estimator, while the new global-method regression cases in tests/test_trop.py:L3256-L3425 still have no Rust-side analogue in tests/test_rust_backend.py:L1152-L1685. Concrete fix: rewrite the API example to method='global' with residual-based ATT wording, and add Rust/Python parity coverage for control-only masking, post-hoc residual ATT extraction, and treated-outcome isolation.

Path to Approval

  1. Make the Rust global FE solve methodologically equivalent to the documented/Python exact weighted-LS solve, and add a parity test that compares Rust-selected λs / fitted (mu, alpha, beta, L) against Python for both lambda_nn=inf and finite lambda_nn.
  2. Introduce n_valid_treated in the global path, use it for df_trop and reported metadata, and add tests for both partial-missing-treated and all-missing-treated cases so the zero-valid-treated path errors or warns explicitly.

…bootstrap SE

- Rust solve_joint_no_lowrank: check convergence across all params (mu, alpha,
  beta), not just mu — fixes premature termination and Rust/Python divergence
- Global + twostep: use n_valid_treated (finite outcomes only) for df_trop and
  results.n_treated_obs; skip NaN Y in twostep loop to prevent NaN poisoning
- Return np.nan (not 0.0) SE when <2 bootstrap draws succeed (all 3 paths)
- Update API docs: method='joint' example → method='global'
- Fix stale FISTA reference in REGISTRY.md global section
- Add edge case docs for partial/all-NaN treated outcomes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@igerber
Owner Author

igerber commented Mar 14, 2026

/ai-review

@github-actions

🔁 AI review rerun (requested by @igerber)

Head SHA: 9d4c80caca3f1ef9489a2f6614b67f41af00c88f


Overall Assessment

✅ Looks good

Executive Summary

Methodology

No unmitigated P0/P1 findings.

  • Severity P3. Impact: the registry now documents the global estimator correctly at a high level, but the implementation subsection is still Python-specific (np.linalg.lstsq) and does not use the project’s explicit **Note:** convention for documented adaptations, even though Rust reaches the same weighted-LS objective through a different numerical route (docs/methodology/REGISTRY.md:L1215-L1229, rust/src/trop.rs:L1207-L1315). That can cause avoidable future methodology-review churn, but it is not a correctness defect under the project rule that equivalent numerical approaches are acceptable. Concrete fix: make the no-lowrank description backend-agnostic and/or add a **Note:** label explicitly marking the global path as a documented adaptation.

Code Quality

No material diff-specific findings.

Performance

No material diff-specific findings.

Maintainability

No material diff-specific findings.

Tech Debt

No material diff-specific findings.

Security

No material findings in the diff.

Documentation/Tests

  • Severity P3. Impact: the new low-rank Rust parity regression only asserts ATT and unit effects; it does not assert time_effects or the low-rank/fitted matrix, which are the quantities most directly changed by the corrected threshold/FISTA rewrite (tests/test_rust_backend.py:L1755-L1818). A backend drift in beta or L could therefore slip through while ATT on this fixture still matches. Concrete fix: extend the finite-lambda_nn parity test to compare time_effects and either factor_matrix or the full fitted counterfactual matrix between Rust and Python.

The test patches _fit_joint_with_fixed_lambda to force failures, but
on CI with Rust available, the Rust bootstrap path runs instead of the
Python fallback. Disable Rust backend in the test to exercise the
Python return path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@igerber igerber merged commit 34f6c22 into main Mar 14, 2026
9 of 10 checks passed
@igerber igerber deleted the feature/trop-global-method-1w-masking branch March 14, 2026 23:19